import math
from util import display

# Probabilities at or below this are not passed to math.log (guards the
# log(0) "math domain error").
_LOG_FLOOR = 4e-44
# Per-sample cost charged instead of -log(p) when p <= _LOG_FLOOR;
# approximately -log(_LOG_FLOOR).
_SATURATED_COST = 99.9


def _neg_log(p):
	"""Return -log(p), or _SATURATED_COST when p is too small to log safely."""
	if p > _LOG_FLOOR:
		return -math.log(p)
	return _SATURATED_COST


def _epoch_gradients(samples, nu):
	"""Run one pass over samples; return averaged (grad_w1, grad_w2, grad_b, cost).

	samples must be non-empty. Each sample is indexed as s[0], s[1] (inputs)
	and s[2] (label, truthy means 1). The error term p - label matches the
	cross-entropy gradient only if nu.active returns a probability in [0, 1]
	(e.g. a sigmoid output) -- assumption, confirm against the neuron class.
	"""
	grad_w1 = grad_w2 = grad_b = cost = 0.0
	for s in samples:
		x1, x2 = s[0], s[1]
		p = nu.active(x1, x2)
		if s[2]:  # label = 1
			err = p - 1
			cost += _neg_log(p)
		else:  # label = 0
			err = p
			cost += _neg_log(1 - p)
		grad_w1 += x1 * err  # for weight 1
		grad_w2 += x2 * err  # for weight 2
		grad_b += err        # for bias (input fixed at 1.0)
	n = len(samples)
	return grad_w1 / n, grad_w2 / n, grad_b / n, cost / n


def _show_progress(epoch, cost, nu):
	"""Draw the epoch number, average cost and current parameters on display."""
	display.fill(0)
	display.text("Train... E:", 10, 0, 1)
	display.text(str(epoch), 90, 0, 1)
	display.text("Costs=", 10, 10, 1)
	display.text(str(cost), 60, 10, 1)
	display.text("w1=", 10, 30, 1)
	display.text(str(nu.w1), 50, 30, 1)
	display.text("w2=", 10, 40, 1)
	display.text(str(nu.w2), 50, 40, 1)
	display.text(" b=", 10, 50, 1)
	display.text(str(nu.b), 50, 50, 1)
	display.show()


def train(dataSet, nu, rate, epochs=100):
	"""Fit nu.w1, nu.w2 and nu.b by full-batch gradient descent on dataSet.

	Calls nu.reset() first, then for each epoch averages the gradient over
	all samples, steps each parameter by -rate * gradient, and shows the
	epoch, average cost and parameters on the display.

	Args:
		dataSet: object whose .samples is a sequence of (x1, x2, label).
		nu: object exposing reset(), active(x1, x2) and mutable w1, w2, b.
		rate: learning rate (step size).
		epochs: number of passes over the data (default 100).

	Returns None. If dataSet.samples is empty, returns right after
	nu.reset() without training, since there is nothing to average over.
	"""
	nu.reset()
	samples = dataSet.samples
	if not samples:
		return
	for epoch in range(epochs):
		g_w1, g_w2, g_b, cost = _epoch_gradients(samples, nu)
		# update parameters
		nu.w1 -= rate * g_w1
		nu.w2 -= rate * g_w2
		nu.b -= rate * g_b
		_show_progress(epoch, cost, nu)
